/*

 * Licensed to the Apache Software Foundation (ASF) under one

 * or more contributor license agreements.  See the NOTICE file

 * distributed with this work for additional information

 * regarding copyright ownership.  The ASF licenses this file

 * to you under the Apache License, Version 2.0 (the

 * "License"); you may not use this file except in compliance

 * with the License.  You may obtain a copy of the License at

 *

 *     http://www.apache.org/licenses/LICENSE-2.0

 *

 * Unless required by applicable law or agreed to in writing, software

 * distributed under the License is distributed on an "AS IS" BASIS,

 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 * See the License for the specific language governing permissions and

 * limitations under the License.

 */

package com.bff.gaia.unified.sdk.util;



import com.bff.gaia.unified.sdk.coders.Coder;

import com.bff.gaia.unified.sdk.coders.CoderException;

import com.bff.gaia.unified.sdk.util.common.ReflectHelpers;

import org.xerial.snappy.SnappyInputStream;

import org.xerial.snappy.SnappyOutputStream;



import java.io.*;

import java.lang.reflect.Proxy;

import java.util.Arrays;



import static com.bff.gaia.unified.sdk.util.CoderUtils.decodeFromByteArray;

import static com.bff.gaia.unified.sdk.util.CoderUtils.encodeToByteArray;

import static com.bff.gaia.unified.vendor.guava.com.google.common.base.Preconditions.checkState;



/**
 * Utilities for working with {@link Serializable} values and {@link Coder}s: Snappy-compressed
 * Java serialization round trips and serializability checks.
 */
public class SerializableUtils {
  /**
   * Serializes the argument into an array of bytes, and returns it.
   *
   * <p>The object stream is wrapped in a {@link SnappyOutputStream}, so the returned bytes are the
   * Snappy-compressed form and are meant to be read back with {@link #deserializeFromByteArray}.
   *
   * @param value the object to serialize
   * @return the compressed, serialized form of {@code value}
   * @throws IllegalArgumentException if there are errors when serializing
   */
  public static byte[] serializeToByteArray(Serializable value) {
    try {
      ByteArrayOutputStream buffer = new ByteArrayOutputStream();
      // The stream must be closed before the buffer is read so that all pending (compressed)
      // output has been pushed into the buffer.
      try (ObjectOutputStream oos = new ObjectOutputStream(new SnappyOutputStream(buffer))) {
        oos.writeObject(value);
      }
      return buffer.toByteArray();
    } catch (IOException exn) {
      throw new IllegalArgumentException("unable to serialize " + value, exn);
    }
  }

  /**
   * Deserializes an object from the given array of bytes, e.g., as serialized using {@link
   * #serializeToByteArray}, and returns it.
   *
   * @param encodedValue the Snappy-compressed serialized bytes
   * @param description a human-readable description of what is being deserialized; only used in
   *     the error message
   * @return the deserialized object
   * @throws IllegalArgumentException if there are errors when deserializing, using the provided
   *     description to identify what was being deserialized
   */
  public static Object deserializeFromByteArray(byte[] encodedValue, String description) {
    try {
      try (ObjectInputStream ois =
          new ContextualObjectInputStream(
              new SnappyInputStream(new ByteArrayInputStream(encodedValue)))) {
        return ois.readObject();
      }
    } catch (IOException | ClassNotFoundException exn) {
      throw new IllegalArgumentException("unable to deserialize " + description, exn);
    }
  }

  /**
   * Verifies that {@code value} survives a serialization round trip by cloning it.
   *
   * @param value the value to check; must not be {@code null}
   * @return the copy obtained by serializing and deserializing {@code value}
   * @throws IllegalArgumentException if the value cannot be serialized or deserialized
   */
  public static <T extends Serializable> T ensureSerializable(T value) {
    return clone(value);
  }

  /**
   * Creates a copy of {@code value} by serializing it and deserializing the result.
   *
   * <p>While the copy is made, the calling thread's context class loader is temporarily replaced
   * by a loader that resolves the value's class to the very same {@link Class} object; the
   * previous context class loader is always restored before returning.
   *
   * @param value the value to copy; must not be {@code null}
   * @return the deserialized copy
   * @throws IllegalArgumentException if the value cannot be serialized or deserialized
   */
  public static <T extends Serializable> T clone(T value) {
    final Thread thread = Thread.currentThread();
    final ClassLoader tccl = thread.getContextClassLoader();
    final Class<?> valueClass = value.getClass();

    ClassLoader loader = tccl;
    if (tccl == null) {
      // A thread is allowed to have no context class loader; go straight to the class's own
      // loader instead of failing with a NullPointerException.
      loader = valueClass.getClassLoader();
    } else {
      try {
        if (tccl.loadClass(valueClass.getName()) != valueClass) {
          // The context loader sees a different class of the same name.
          loader = valueClass.getClassLoader();
        }
      } catch (final NoClassDefFoundError | ClassNotFoundException ignored) {
        // The context loader cannot see the class at all.
        loader = valueClass.getClassLoader();
      }
    }
    if (loader == null) {
      loader = tccl; // will likely fail but the best we can do
    }

    // NOTE(review): presumably ReflectHelpers.findClassLoader(), used while deserializing,
    // consults the context class loader installed here -- confirm in ReflectHelpers.
    thread.setContextClassLoader(loader);
    try {
      // The round trip yields an instance of the value's own class, so the cast is safe. The
      // annotation sits on the initialised declaration so that it actually covers the cast.
      @SuppressWarnings("unchecked")
      final T copy = (T) deserializeFromByteArray(serializeToByteArray(value), value.toString());
      return copy;
    } finally {
      thread.setContextClassLoader(tccl);
    }
  }

  /**
   * Serializes a Coder and verifies that it can be correctly deserialized.
   *
   * <p>Throws an {@link IllegalArgumentException} if the serialized Coder cannot be deserialized,
   * and fails the {@code checkState} precondition if the deserialized instance is not equal to the
   * original.
   *
   * @param coder the coder to check
   * @return the deserialized Coder
   */
  public static Coder<?> ensureSerializable(Coder<?> coder) {
    // Make sure that Coders are java serializable as well since
    // they are regularly captured within DoFn's.
    Coder<?> copy = (Coder<?>) ensureSerializable((Serializable) coder);

    checkState(
        coder.equals(copy),
        "Coder not equal to original after serialization, indicating that the Coder may not "
            + "implement serialization correctly.  Before: %s, after: %s",
        coder,
        copy);

    return copy;
  }

  /**
   * Serializes an arbitrary T with the given {@code Coder<T>} and verifies that it can be correctly
   * deserialized.
   *
   * @param coder the coder used to encode and then decode {@code value}
   * @param value the value to round-trip through {@code coder}
   * @param errorContext prefix for the error message when encoding or decoding fails
   * @return the value decoded from the encoded form of {@code value}
   * @throws IllegalArgumentException if the coder fails to encode or to decode the value
   */
  public static <T> T ensureSerializableByCoder(Coder<T> coder, T value, String errorContext) {
    byte[] encodedValue;
    try {
      encodedValue = CoderUtils.encodeToByteArray(coder, value);
    } catch (CoderException exn) {
      // TODO: Put in better element printing:
      // truncate if too long.
      throw new IllegalArgumentException(
          errorContext + ": unable to encode value " + value + " using " + coder, exn);
    }
    try {
      return CoderUtils.decodeFromByteArray(coder, encodedValue);
    } catch (CoderException exn) {
      // TODO: Put in better encoded byte array printing:
      // use printable chars with escapes instead of codes, and
      // truncate if too long.
      throw new IllegalArgumentException(
          errorContext
              + ": unable to decode "
              + Arrays.toString(encodedValue)
              + ", encoding of value "
              + value
              + ", using "
              + coder,
          exn);
    }
  }

  /**
   * An {@link ObjectInputStream} that resolves classes and proxy interfaces through the class
   * loader returned by {@code ReflectHelpers.findClassLoader()} before falling back to the JVM
   * default resolution.
   */
  private static final class ContextualObjectInputStream extends ObjectInputStream {
    private ContextualObjectInputStream(final InputStream in) throws IOException {
      super(in);
    }

    /**
     * Resolves the described class with the contextual class loader, deferring to the default
     * {@link ObjectInputStream} behavior when that loader cannot find it.
     */
    @Override
    protected Class<?> resolveClass(final ObjectStreamClass classDesc)
        throws IOException, ClassNotFoundException {
      // Note: this stays aligned with the JVM default and performs no class filtering. Filtering
      // may be needed here if the stream can ever carry untrusted data (deserialization attacks).
      final String n = classDesc.getName();
      final ClassLoader classloader = ReflectHelpers.findClassLoader();
      try {
        return Class.forName(n, false, classloader);
      } catch (final ClassNotFoundException e) {
        return super.resolveClass(classDesc);
      }
    }

    /**
     * Resolves a dynamic proxy class implementing the named interfaces with the contextual class
     * loader, deferring to the default {@link ObjectInputStream} behavior when that loader cannot
     * find one of the interfaces (the same fallback {@link #resolveClass} applies).
     */
    @Override
    protected Class<?> resolveProxyClass(final String[] interfaces)
        throws IOException, ClassNotFoundException {
      final ClassLoader classloader = ReflectHelpers.findClassLoader();

      final Class<?>[] cinterfaces = new Class<?>[interfaces.length];
      try {
        for (int i = 0; i < interfaces.length; i++) {
          // Same lookup as resolveClass: load without initializing the interface.
          cinterfaces[i] = Class.forName(interfaces[i], false, classloader);
        }
      } catch (final ClassNotFoundException ignored) {
        // The contextual loader cannot see every interface; let the JVM default try.
        return super.resolveProxyClass(interfaces);
      }

      try {
        return Proxy.getProxyClass(classloader, cinterfaces);
      } catch (final IllegalArgumentException e) {
        throw new ClassNotFoundException(
            "unable to resolve proxy class for interfaces " + Arrays.toString(interfaces), e);
      }
    }
  }
}